import torch
try:
    from .utils.resnet_blocks import _resnet, BasicBlock, Bottleneck
except:
    from utils.resnet_blocks import _resnet, BasicBlock, Bottleneck


def Speaker_Encoder(**kwargs):
    """Build the speaker encoder as a ResNet-50-style network.

    Delegates to ``_resnet`` with the architecture name ``'resnet50'``,
    ``Bottleneck`` blocks and the block-count layout ``[3, 4, 6, 3]``.
    The layout is presumably the per-stage block counts of the standard
    ResNet-50; ``_resnet`` defines what the list means.

    Args:
        **kwargs: Forwarded unchanged to ``_resnet``.

    Returns:
        Whatever ``_resnet`` builds for this configuration.
    """
    block_counts = [3, 4, 6, 3]
    encoder = _resnet('resnet50', Bottleneck, block_counts, **kwargs)
    return encoder


if __name__ == "__main__":
    # Smoke test: build the encoder, report its parameter count, and check
    # that one forward pass on random input runs and yields an output.
    model = Speaker_Encoder()
    # Generator expression: no throwaway list just to feed sum().
    total = sum(param.numel() for param in model.parameters())
    print(f"total param: {total / 1e6:.3f}M")

    # Dummy batch shaped (10, 64, 200).
    # NOTE(review): presumably (batch, feature bins, frames) -- confirm that
    # the `_resnet` model accepts this 3-D input layout.
    data = torch.randn(10, 64, 200)
    print(data.shape)

    # eval() switches any train-only layers (e.g. BatchNorm/Dropout, if the
    # model has them) to inference behaviour, so the smoke test does not
    # mutate running statistics. no_grad() skips building an autograd graph
    # that this shape check never uses.
    model.eval()
    with torch.no_grad():
        outputs = model(data)
    print(outputs.shape)
 
